{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1ubnMn66ga4Z"
   },
   "source": [
    "# A guided tour of Flax\n",
    "\n",
    "This notebook provides a guided tour of the features of Flax, starting from plain JAX, moving on to the Flax Module abstraction, and ending with more advanced functionality."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import jax\n",
    "from jax import numpy as jnp, random, jit, lax\n",
    "\n",
    "import flax\n",
    "from flax import nn, optim\n",
    "\n",
    "# init jax with some random compute. \n",
    "# JAX might complain about not having access to a GPU or TPU.\n",
    "_ = jnp.square(2.) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "HMNMi98ejnll"
   },
   "source": [
    "## Intro to JAX\n",
    "\n",
    "[JAX](https://github.com/google/jax) is a numerical computation library which aims to replicate the NumPy API.\n",
    "\n",
    "A few important things to know about JAX:\n",
    "\n",
    " - It is functional. This means no in-place ops or sliced assignments. Functions should not read inputs from or write outputs to global state.\n",
    " - JAX works best on functions where the bulk of the computation is in numpy calls, with Python control-flow generally limited to operate on array shapes and non-array data. See [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).\n",
    " - JAX can execute computations on CPUs, GPUs, and TPUs.\n",
    " - Functions using the `jax.numpy` API can be traced for automatic transformations:\n",
    "\n",
    "   - **jit**: compile a function using XLA, enabling fast execution.\n",
    "   - **grad**: take the gradient of a function.\n",
    "   - **vmap**: add a batch dimension to a function.\n",
    "   - **pmap**: split a computation across devices based on the first dimension of each input argument.\n",
    "\n",
    "\n",
    "## Neural Networks in JAX without Flax\n",
    "\n",
    "Before we dive into Flax, let's look at what a typical neural network component looks like when written in \"native\" JAX.\n",
    "\n",
    "We decompose a learnable linear layer into two parts: an initializer function, which uses a JAX PRNGKey to generate a random kernel and bias, and an apply function, which computes the linear transformation using a set of parameters and some inputs."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dense_init(rng, in_features, out_features,\n",
    "               kernel_init=jax.nn.initializers.lecun_normal(),\n",
    "               bias_init=jax.nn.initializers.zeros):\n",
    "  k1, k2 = random.split(rng)\n",
    "  # init functions take a PRNGKey and a shape tuple and return ndarrays.\n",
    "  kernel = kernel_init(k1, (in_features, out_features))\n",
    "  bias = bias_init(k2, (out_features,))\n",
    "  return kernel, bias\n",
    "\n",
    "def dense_apply(params, inputs):\n",
    "  kernel, bias = params\n",
    "  return jnp.dot(inputs, kernel) + bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hwwff8vUJ6t7"
   },
   "source": [
    "Functional programming without abstractions naturally results in somewhat verbose but very explicit code.\n",
    "\n",
    "Note how the random number generators and parameters are passed on explicitly to functions.\n",
    "JAX has no concept of variables so we cannot hide the parameters in variables somewhere.\n",
    "Similarly, there is no global random number generator which updates an internal seed."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = dense_init(random.PRNGKey(0), in_features=2, out_features=4)\n",
    "print(params)"
   ]
  },
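  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The explicit handling of random number generators noted above can be illustrated with a short sketch. It is not part of the original tour, and the names `demo_key`, `k1`, and `k2` are chosen here purely for illustration. Reusing a key reproduces the same values, while keys obtained from `random.split` give different ones."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import numpy as jnp, random\n",
    "\n",
    "demo_key = random.PRNGKey(42)  # hypothetical seed, any integer works\n",
    "\n",
    "# the same key always produces the same numbers\n",
    "a = random.normal(demo_key, (3,))\n",
    "b = random.normal(demo_key, (3,))\n",
    "print('same key, same values:', bool(jnp.array_equal(a, b)))\n",
    "\n",
    "# splitting yields new keys, each of which gives its own stream of values\n",
    "k1, k2 = random.split(demo_key)\n",
    "print(random.normal(k1, (3,)))\n",
    "print(random.normal(k2, (3,)))"
   ]
  },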
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "uiEnIech21gc"
   },
   "source": [
    "Once we have generated a set of parameters, it is easy to apply them to some inputs."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = jnp.ones((1, 2))\n",
    "dense_apply(params, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "VPKvLaZl3BP1"
   },
   "source": [
    "Because everything is functional we can use the functional transformations that JAX provides to do useful things like taking gradients to optimize the model."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(params, x):\n",
    "  y = dense_apply(params, x)\n",
    "  return jnp.mean(y ** 2)\n",
    "grad_fn = jax.grad(loss_fn) # by default jax.grad takes the gradient w.r.t. the first argument\n",
    "grad_fn(params, x)"
   ]
  },
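  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Gradients are only one of the transformations listed in the introduction. As a minimal sketch, the cell below applies `jax.jit` and `jax.vmap` to a standalone linear layer that mirrors `dense_apply`. The names `linear`, `demo_params`, `single`, and `batch` are introduced here for illustration and do not appear elsewhere in this notebook."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "from jax import numpy as jnp, random\n",
    "\n",
    "# a standalone linear layer, mirroring `dense_apply` above\n",
    "def linear(params, inputs):\n",
    "  kernel, bias = params\n",
    "  return jnp.dot(inputs, kernel) + bias\n",
    "\n",
    "demo_params = (random.normal(random.PRNGKey(0), (2, 4)), jnp.zeros((4,)))\n",
    "\n",
    "# jit: compile the function with XLA; the result matches the uncompiled version\n",
    "fast_linear = jax.jit(linear)\n",
    "single = jnp.ones((2,))\n",
    "print(jnp.allclose(fast_linear(demo_params, single), linear(demo_params, single)))\n",
    "\n",
    "# vmap: map over the leading axis of the inputs while sharing the parameters\n",
    "batched_linear = jax.vmap(linear, in_axes=(None, 0))\n",
    "batch = jnp.ones((8, 2))\n",
    "print(batched_linear(demo_params, batch).shape)  # (8, 4)"
   ]
  },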
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "__CrzjEfqE5L"
   },
   "source": [
    "## Simplifying Neural Networks in JAX: Flax Modules\n",
    "\n",
    "The core of Flax is the Module abstraction.\n",
    "Modules allow you to write parameterized functions just as if you were writing a normal numpy function with JAX.\n",
    "The Module API allows you to declare parameters and use them directly with the JAX APIs.\n",
    "\n",
    "A few things to know about Modules:\n",
    "\n",
    "  1. A Module is created by defining a subclass of `flax.nn.Module` and implementing the `apply` method.\n",
    "  2. Parameters are declared using `self.param(name, shape, init_func)`, which returns an initialized parameter value.\n",
    "  3. `Dense.init(rng, ...)` and `Dense.call(params, ...)` play the same roles as the `dense_init` and `dense_apply` functions implemented earlier (`init` additionally returns the output of the Module).\n",
    "\n",
    "Now let's redefine the dense layer using Flax Modules."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Dense(nn.Module):\n",
    "  \"\"\"A learned linear transformation.\"\"\"\n",
    "  def apply(self, x, features,\n",
    "            # init functions are of the form (PrngKey, shape) => init_value\n",
    "            kernel_init=jax.nn.initializers.lecun_normal(),\n",
    "            bias_init=jax.nn.initializers.zeros):\n",
    "    \"\"\"The main entry point to a Module. Represents the function that,\n",
    "    given inputs and hyper-parameters, computes an output. The actual arguments\n",
    "    (inputs and hyper-parameters) are user-controlled and depend on the Module's functionality.\n",
    "    For this example:\n",
    "    \n",
    "      * `x`: the input, an array of shape `(..., in_features)`.\n",
    "      * `features`: the number of outputs, an integer.\n",
    "      * `kernel_init`: the initializer for the kernel.\n",
    "      * `bias_init`: the initializer for the biases.\n",
    "    \"\"\"\n",
    "    in_features = x.shape[-1]\n",
    "    kernel_shape = (in_features, features)\n",
    "    kernel = self.param('kernel', kernel_shape, kernel_init)\n",
    "    bias = self.param('bias', (features,), bias_init)\n",
    "    return jnp.dot(x, kernel) + bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y, params = Dense.init(random.PRNGKey(0), x, features=4)\n",
    "print(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Dense.call(params, x, features=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ednxR882N4kX"
   },
   "source": [
    "Note that both `init` and `call` end up using the same `apply` function. That is why we must specify all the inputs\n",
    "and hyper-parameters (here, the number of features) in both `init` and `call`. Often the hyper-parameters are the same\n",
    "in every call to `init` and `call`.\n",
    "For these situations, we can use `Module.partial` to fix these arguments in advance. `partial` takes keyword arguments\n",
    "and returns a new Module for which the given arguments are already applied.\n",
    "It can be thought of as the equivalent of `functools.partial` for Modules."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "module = Dense.partial(features=4) # Module with concrete hyper parameters, ready to be initialized\n",
    "_, params = module.init(random.PRNGKey(0), x)\n",
    "module.call(params, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3Q1oe0WhOkxC"
   },
   "source": [
    "### Composition\n",
    "\n",
    "Modules can be composed to form more complex Modules.\n",
    "\n",
    "Within a Module's `apply` function other modules behave just like functions.\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# same as flax.nn.relu\n",
    "def relu(x):\n",
    "  return jnp.maximum(0., x)\n",
    "\n",
    "class MLP(nn.Module):\n",
    "  \"\"\"Multi Layer Perceptron.\"\"\"\n",
    "  \n",
    "  def apply(self, x,\n",
    "            hidden_features,\n",
    "            output_features,\n",
    "            activation_fn):\n",
    "\n",
    "    z = Dense(x, hidden_features)\n",
    "    h = activation_fn(z)\n",
    "    y = Dense(h, output_features)\n",
    "    return y\n",
    "\n",
    "module = MLP.partial(hidden_features=8, output_features=4, activation_fn=relu)\n",
    "y, params = module.init(random.PRNGKey(0), x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "o6CPJofni4vf"
   },
   "source": [
    "The `params` returned by `init` have a nested structure of lists, tuples, dicts,\n",
    "and other types that can contain arrays; we call such a structure a pytree.\n",
    "When we compose Modules as in our example, `params` is a structure of nested dictionaries.\n",
    "We can use `jax.tree_map` to apply a function to each leaf of a pytree, e.g., to reveal\n",
    "the `params` structure of the MLP model (recall that `x.shape = (1, 2)`).\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jax.tree_map(np.shape, params)"
   ]
  },
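  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the notion of a pytree concrete without involving Flax, here is a small sketch on a hand-written nested dictionary. The `tree` below and its key names are made up for illustration. It uses `jax.tree_util.tree_map`, the fully qualified name of the function used above, together with `jax.tree_util.tree_leaves`, which flattens a pytree into a list of its leaves."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "from jax import numpy as jnp\n",
    "\n",
    "# a hand-written pytree: nested dicts with arrays as leaves (hypothetical names)\n",
    "tree = {'hidden': {'kernel': jnp.ones((2, 8)), 'bias': jnp.zeros((8,))},\n",
    "        'out': {'kernel': jnp.ones((8, 4)), 'bias': jnp.zeros((4,))}}\n",
    "\n",
    "# tree_map applies a function to every leaf and preserves the structure\n",
    "doubled = jax.tree_util.tree_map(lambda leaf: 2. * leaf, tree)\n",
    "print(jax.tree_util.tree_map(jnp.shape, doubled))\n",
    "\n",
    "# tree_leaves flattens the structure into a list of arrays\n",
    "num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(tree))\n",
    "print(num_params)  # 16 + 8 + 32 + 4 = 60"
   ]
  },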
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "OGyMhq1rmBRr"
   },
   "source": [
    "#### Module name\n",
    "By default Flax uses integers as keys for the parameters of sub-Modules. By passing the `name` argument we can control the parameter structure and make it more meaningful."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NamedMLP(nn.Module):\n",
    "  def apply(self, x,\n",
    "            hidden_features,\n",
    "            output_features,\n",
    "            activation_fn):\n",
    "\n",
    "    z = Dense(x, hidden_features, name='hidden')\n",
    "    h = activation_fn(z)\n",
    "    y = Dense(h, output_features, name='out')\n",
    "    return y\n",
    "\n",
    "module = NamedMLP.partial(hidden_features=8, output_features=4, activation_fn=relu)\n",
    "_, params = module.init(random.PRNGKey(0), x)\n",
    "jax.tree_map(np.shape, params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_L5AOQ97mZAA"
   },
   "source": [
    "### Parameter sharing\n",
    "\n",
    "Sometimes a Module should be applied to multiple inputs with one set of parameters.\n",
    "We can make a Module for which parameters are shared between calls using `Module.shared`.\n",
    "Just like with `Module.partial` we can pass keyword arguments that are fixed for each call to the Module."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleRNN(nn.Module):\n",
    "  def apply(self, x, iterations=3):\n",
    "    dense = Dense.shared(\n",
    "        features=x.shape[-1],\n",
    "        kernel_init=jax.nn.initializers.orthogonal(),\n",
    "        name='cell')\n",
    "    ys = []\n",
    "    for i in range(iterations):\n",
    "      x = dense(x)\n",
    "      ys.append(x)\n",
    "    return ys"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Bn396SL6RARD"
   },
   "source": [
    "We call the Dense layer named 'cell' three times, but only one set of parameters shows up in the parameter structure due to weight sharing."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ys, params = SimpleRNN.init(random.PRNGKey(0), x)\n",
    "print(ys)\n",
    "jax.tree_map(np.shape, params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lFUvB2IcZJpu"
   },
   "source": [
    "### Shape inference\n",
    "\n",
    "Previously we initialized the model by passing in some inputs.\n",
    "This is useful because it allows for Modules which automatically infer the shape of parameters based on inputs. It can also help catch errors in the model early, in the initialization phase of a program.\n",
    "\n",
    "Nonetheless, `Module.init` includes some unnecessary overhead because typically we are not interested in the actual output of the model during initialization. Therefore, we can use JAX's built-in lazy evaluation to get the benefits of shape inference without doing any unnecessary compute.\n",
    "\n",
    "`Module.init_by_shape` returns only the shape and dtype of outputs but still creates fully initialized parameters. If you want to use initializers that (indirectly) depend on the values (not shape) of the inputs you should keep using `Module.init`."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_spec = [(1, 2)] # the input specification is a list of shape tuples\n",
    "out_spec, params = SimpleRNN.init_by_shape(random.PRNGKey(0), input_spec)\n",
    "# TODO: uncomment this line  once __repr__ is fixed in jax\n",
    "# print('out_spec:', out_spec)\n",
    "jax.tree_map(np.shape, params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EAdFRvBwRran"
   },
   "source": [
    "### Model\n",
    "\n",
    "Modules make it easy to keep track of parameters inside a model, but so far we still had to explicitly keep track of the parameter structure and the `init` and `call` functions.\n",
    "\n",
    "Model is a thin abstraction around a Module and a set of parameters.\n",
    "A Model instance is callable and functional (e.g., changing parameters requires a new model instance).\n",
    "\n",
    "Using `Module.init` or `Module.init_by_shape` will create a newly initialized set of parameters. Then you can wrap the module and the initialized parameters in a `Model` instance. "
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = jnp.ones((1, 2))\n",
    "module = Dense.partial(features=4)\n",
    "ys, initial_params = module.init(random.PRNGKey(0), x)\n",
    "model = nn.Model(module, initial_params)\n",
    "jax.tree_map(np.shape, model.params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CDvnJ1VP_qfz"
   },
   "source": [
    "Parameters can be updated using the `Model.replace` method."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "biased_model = model.replace(params={'kernel': model.params['kernel'], 'bias': model.params['bias'] + 1.})\n",
    "biased_model.params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "G8PAwSzbA29M"
   },
   "source": [
    "Model is registered as a JAX pytree container object which means that it can be passed to JAX transformations and `jax.tree_map`.\n",
    "\n",
    "For example we can take gradients w.r.t. a Model object. The returned Model object will contain the gradients corresponding to each parameter."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(model):\n",
    "  y = model(x)\n",
    "  return jnp.mean(y ** 2)\n",
    "\n",
    "model_grad = jax.grad(loss_fn)(model)\n",
    "model_grad.params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "yHxj5gVhAzLX"
   },
   "source": [
    "### State\n",
    "\n",
    "Flax allows stateful operations to happen within a limited scope.\n",
    "\n",
    "Stateful Modules are defined using the `Module.state` API. It returns a state object with a `value` property that can be assigned to.\n",
    "\n",
    "A typical example of a stateful Module is BatchNorm, which maintains a moving average of batch statistics (mean, variance).\n",
    "During training the moving averages are updated so that they can be used at test time.\n",
    "\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# simplified version of nn.BatchNorm\n",
    "class BatchNorm(nn.Module):\n",
    "  def apply(self, x, red_axis=0, eps=1e-5,\n",
    "            momentum=0.99, training=False,\n",
    "            gamma_init=nn.initializers.ones,\n",
    "            beta_init=nn.initializers.zeros):\n",
    "\n",
    "    # compute the moments of the input\n",
    "    mean = x.mean(red_axis, keepdims=True)\n",
    "    var = jnp.square(x - mean).mean(red_axis, keepdims=True)\n",
    "\n",
    "    # define the state variables\n",
    "    ra_mean = self.state('mean', mean.shape, nn.initializers.zeros)\n",
    "    ra_var = self.state('var', var.shape, nn.initializers.ones)\n",
    "\n",
    "    if not self.is_initializing():  # during init we ignore the moving averages completely\n",
    "      if training:\n",
    "        # during training the moving averages are updated\n",
    "        alpha = 1. - momentum\n",
    "        ra_mean.value += alpha * (mean - ra_mean.value)\n",
    "        ra_var.value += alpha * (var - ra_var.value)\n",
    "      else:\n",
    "        # if we are not training we use the moving averages\n",
    "        mean = ra_mean.value\n",
    "        var = ra_var.value\n",
    "\n",
    "    # standardize the input\n",
    "    y = (x - mean) / jnp.sqrt(var + eps)\n",
    "\n",
    "    # learn the scale and bias of the output\n",
    "    gamma = self.param('gamma', mean.shape, gamma_init)\n",
    "    beta = self.param('beta', mean.shape, beta_init)\n",
    "    return gamma * y + beta"
   ]
  },
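  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The update in the `training` branch above is an exponential moving average. As a rough sketch in plain JAX, with the state passed in and returned explicitly rather than stored in a Module, it could look as follows. The helper `update_moving_average` and the toy values are assumptions made for this illustration."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax import numpy as jnp\n",
    "\n",
    "def update_moving_average(running, batch_stat, momentum=0.99):\n",
    "  # same rule as in BatchNorm above: move a fraction (1 - momentum) towards the batch statistic\n",
    "  alpha = 1. - momentum\n",
    "  return running + alpha * (batch_stat - running)\n",
    "\n",
    "running_mean = jnp.zeros((1, 2))\n",
    "batch_mean = jnp.array([[1., 2.]])  # toy batch statistic, identical at every step\n",
    "for step in range(3):\n",
    "  # the state is threaded through explicitly; nothing is mutated in place\n",
    "  running_mean = update_moving_average(running_mean, batch_mean, momentum=0.5)\n",
    "print(running_mean)  # [[0.875 1.75]]"
   ]
  },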
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QKoNxlq84lHk"
   },
   "source": [
    "Stateful modules require special care when used. The `nn.stateful` context manager defines a scope in which stateful operations are allowed. Outside of this scope the state becomes immutable.\n",
    "\n",
    "The state is stored in a `nn.Collection` object which internally stores the state as a dictionary.\n",
    "\n",
    "`nn.stateful` takes a Collection containing the current state and returns a new Collection that contains the updated state. By default a new Collection will be created.\n",
    "\n",
    "When using `nn.stateful(state, mutable=False)` the state can be read but any updates will raise an error. This is often useful during test time to guarantee that test data does not affect the model."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(nn.Module):\n",
    "\n",
    "  def apply(self, x, training=False):\n",
    "    x = Dense(x, features=4)\n",
    "    x = BatchNorm(x, training=training, momentum=0., name='batch_norm')\n",
    "    return x\n",
    "\n",
    "dist_a = lambda rng, shape: random.normal(rng, shape) * jnp.array([[1., 3.]])\n",
    "\n",
    "x_a = dist_a(random.PRNGKey(1), (1024, 2))\n",
    "print('std. deviation of input:', x_a.std(0))\n",
    "\n",
    "with nn.stateful() as init_state:\n",
    "  y, params = MyModel.init(random.PRNGKey(2), x_a)\n",
    "print('std. deviation of output (init):', y.std(0))\n",
    "\n",
    "with nn.stateful(init_state) as new_state:\n",
    "  y = MyModel.call(params, x_a, training=True)\n",
    "print('std. deviation of output (training):', y.std(0))\n",
    "\n",
    "with nn.stateful(new_state, mutable=False):\n",
    "  y = MyModel.call(params, x_a, training=False)\n",
    "print('std. deviation of output (testing):', y.std(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "M2DJaEVV6I67"
   },
   "source": [
    "The state can be inspected using `Collection.as_dict()`.\n",
    "\n",
    "Each Module has a path-like key into the Collection (e.g. '/some_module/nested_module/dense')."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_state.as_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_state.as_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NWqk1SBZ6iWD"
   },
   "source": [
    "The stateful mechanism forces the user to be explicit about stateful operations.\n",
    "\n",
    "One motivating example for this approach is to enforce that state is not updated at test time.\n",
    "\n",
    "Another benefit is that it is easier to replace the state when necessary.\n",
    "For example, let's say we want to apply this model to a second input distribution (b) with different statistics.\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dist_b = lambda rng, shape: random.normal(rng, shape) * jnp.array([[2., 5.]])\n",
    "\n",
    "x_b = dist_b(random.PRNGKey(1), (1024, 2))\n",
    "\n",
    "with nn.stateful(new_state, mutable=False):\n",
    "  y = MyModel.call(params, x_b, training=False)\n",
    "print(y.std(0)) # this will not be properly normalized!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nWdH5jcr7d05"
   },
   "source": [
    "We can solve the skew in statistics by creating a separate state for this alternative input distribution."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nn.stateful(init_state) as state_b:\n",
    "  y = MyModel.call(params, x_b, training=True)\n",
    "print('std. deviation of output (training):', y.std(0))\n",
    "\n",
    "with nn.stateful(state_b, mutable=False):\n",
    "  y = MyModel.call(params, x_b, training=False)\n",
    "print('std. deviation of output (testing):', y.std(0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KEfIkD3uR8hh"
   },
   "source": [
    "## Optimizer\n",
    "\n",
    "The `flax.optim` package contains a simple api for optimizing a set of parameters using gradient descent algorithms.\n",
    "\n",
    "To illustrate the optimizer api let's first define a simple linear regression problem:"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = random.PRNGKey(0)\n",
    "rng, key1, key2 = random.split(rng, 3)\n",
    "n = 30\n",
    "x = jnp.linspace(-5., 5.)\n",
    "X = random.uniform(key1, (n,), minval=-5., maxval=5.)\n",
    "f = lambda x: 2. * x\n",
    "Y = f(X) + random.normal(key2, (n,))\n",
    "plt.plot(x, f(x))\n",
    "plt.scatter(X, Y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is nothing more than a Dense module with a single feature."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LinearRegression(nn.Module):\n",
    "  def apply(self, x):\n",
    "    # add a singleton dimension to the input and remove the singleton feature dim from the output\n",
    "    return nn.Dense(x[..., None], features=1)[..., 0]\n",
    "\n",
    "rng, key = random.split(rng)\n",
    "_, initial_params = LinearRegression.init(key, X)\n",
    "model = nn.Model(LinearRegression, initial_params)\n",
    "\n",
    "# plot the data together with the line used to generate the data (blue) and the untrained model (orange)\n",
    "plt.plot(x, f(x))\n",
    "plt.plot(x, model(x))\n",
    "plt.scatter(X, Y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use gradient descent with momentum to fit the model to the data.\n",
    "Each optimizer inherits from `flax.optim.OptimizerDef`.\n",
    "The `OptimizerDef` class provides the `init_state` and `apply_gradient` methods, which initialize and update the optimizer state, respectively.\n",
    "\n",
    "`OptimizerDef` does not actually maintain the state and optimized parameters. It is simply a collection of functions.\n",
    "\n",
    "When calling `OptimizerDef.create`, the optimization target and the optimizer state (e.g. the gradient moving average) are wrapped together with the `OptimizerDef` in an instance of `Optimizer`.\n",
    "\n",
    "Optimizers can optimize any pytree of arrays (nested dicts, `Model` instances, etc.), as long as the gradient w.r.t. it is computable by `jax.grad`."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_def = optim.Momentum(learning_rate=0.1, beta=0.9)\n",
    "optimizer = optimizer_def.create(model)\n",
    "train_steps = 100\n",
    "\n",
    "def loss_fn(model):\n",
    "  Y_hat = model(X)\n",
    "  return jnp.square(Y - Y_hat).mean()\n",
    "\n",
    "for i in range(train_steps):\n",
    "  # optimizer.target is passed to the loss_fn\n",
    "  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)\n",
    "  # `apply_gradient` returns a new `Optimizer` instance with the updated target and optimizer state.\n",
    "  optimizer = optimizer.apply_gradient(grad)\n",
    "print('mean square error:', loss)\n",
    "\n",
    "trained_model = optimizer.target\n",
    "print(trained_model.params)\n",
    "plt.plot(x, f(x))\n",
    "plt.plot(x, trained_model(x))\n",
    "plt.scatter(X, Y)\n",
    "plt.show()"
   ]
  },
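  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the split between optimizer state and optimized parameters more tangible, the cell below sketches a momentum update as two pure functions over a pytree in plain JAX. It uses one common formulation of momentum and is not claimed to match the internals of `optim.Momentum`. The names `momentum_init`, `momentum_step`, `toy_loss`, and `toy_params` are hypothetical."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "from jax import numpy as jnp\n",
    "\n",
    "def momentum_init(params):\n",
    "  # optimizer state: one velocity array per parameter, initialized to zero\n",
    "  return jax.tree_util.tree_map(jnp.zeros_like, params)\n",
    "\n",
    "def momentum_step(params, velocity, grads, learning_rate=0.1, beta=0.9):\n",
    "  # pure update: returns new params and new state instead of mutating anything\n",
    "  velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)\n",
    "  params = jax.tree_util.tree_map(lambda p, v: p - learning_rate * v, params, velocity)\n",
    "  return params, velocity\n",
    "\n",
    "def toy_loss(p):\n",
    "  # toy quadratic objective with its minimum at w = 3\n",
    "  return jnp.sum(jnp.square(p['w'] - 3.))\n",
    "\n",
    "toy_params = {'w': jnp.zeros(())}\n",
    "toy_velocity = momentum_init(toy_params)\n",
    "for _ in range(200):\n",
    "  toy_grads = jax.grad(toy_loss)(toy_params)\n",
    "  toy_params, toy_velocity = momentum_step(toy_params, toy_velocity, toy_grads)\n",
    "print(toy_params)  # 'w' should end up close to 3"
   ]
  },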
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Serialization\n",
    "\n",
    "The `flax.serialization` module provides utilities for extracting the state of optimizers, models, and other structures as a dictionary of arrays.\n",
    "It also integrates with MessagePack, which can be used to efficiently serialize the state dictionary in a binary, cross-platform format.\n",
    "\n",
    "The `flax.struct.dataclass` decorator can be used to create a Python dataclass which can be passed into jax transformations like `jit`, `grad`, and `tree_map`. It also integrates with the state dict api."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and Optimizer are also Flax dataclasses.\n",
    "@flax.struct.dataclass\n",
    "class TrainState:\n",
    "  optimizer: optim.Optimizer\n",
    "  step: int\n",
    "\n",
    "state = TrainState(optimizer=optimizer, step=5)\n",
    "flax.serialization.to_state_dict(state)\n",
    ""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Flax dataclasses are immutable. Using the `replace` method a new instance can be created with a set of updated fields."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_state = state.replace(step=6)\n",
    "print(state.step, new_state.step)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `to_bytes` and `from_bytes` functions are used to convert an object to the MessagePack format and back."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = flax.serialization.to_bytes(state)\n",
    "print('num bytes:', len(data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corrupted_state = jax.tree_map(lambda x: 0 * x, state)\n",
    "flax.serialization.to_state_dict(corrupted_state)\n",
    ""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the state from the serialized bytes stored in `data`, using `corrupted_state` as the template\n",
    "restored_state = flax.serialization.from_bytes(corrupted_state, data)\n",
    "flax.serialization.to_state_dict(restored_state)"
   ]
  },
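  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same round trip also works for plain nested dictionaries, which may be easier to follow than the `TrainState` example. The sketch below assumes that `to_bytes` and `from_bytes` behave as shown above. The `checkpoint` and `template` dictionaries are made up for illustration. Note that `from_bytes` needs a target with the same structure as the serialized object."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import flax\n",
    "from jax import numpy as jnp\n",
    "\n",
    "# hypothetical checkpoint: a nested dict with a Python int and two arrays\n",
    "checkpoint = {'step': 5, 'weights': {'kernel': jnp.ones((2, 4)), 'bias': jnp.zeros((4,))}}\n",
    "blob = flax.serialization.to_bytes(checkpoint)\n",
    "print('num bytes:', len(blob))\n",
    "\n",
    "# the template only provides the structure; its values are replaced on restore\n",
    "template = {'step': 0, 'weights': {'kernel': jnp.zeros((2, 4)), 'bias': jnp.zeros((4,))}}\n",
    "restored = flax.serialization.from_bytes(template, blob)\n",
    "print(restored['step'], restored['weights']['kernel'].shape)"
   ]
  },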
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "1KCPh73QSsLr"
   },
   "source": [
    "## Advanced features"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jJyqXS1RSvXg"
   },
   "source": [
    "### Selective Optimization\n",
    "\n",
    "Sometimes we wish to apply a different optimizer to a subset of the model parameters.\n",
    "To illustrate, we will apply weight decay to the bias of the linear model that was trained before.\n",
    "\n",
    "`flax.optim.MultiOptimizer` takes in tuples of traversals and `OptimizerDef` instances. The traversal is responsible for selecting the subset of parameters that should be optimized. We will use `flax.optim.ModelParamTraversal`, which allows you to filter parameters based on their path (e.g. '/hidden/dense/kernel')."
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "slope_opt_def = optim.Momentum(learning_rate=0.1)\n",
    "# by applying decay to the bias parameter it will end up being closer to zero than before\n",
    "bias_opt_def = optim.Momentum(learning_rate=0.1, weight_decay=10.)\n",
    "# select all kernel parameters\n",
    "slope_traversal = optim.ModelParamTraversal(lambda path, param: 'kernel' in path)\n",
    "# select all bias parameters\n",
    "bias_traversal = optim.ModelParamTraversal(lambda path, param: 'bias' in path)\n",
    "optimizer_def = optim.MultiOptimizer((slope_traversal, slope_opt_def), (bias_traversal, bias_opt_def))\n",
    "\n",
    "_, initial_params = LinearRegression.init(random.PRNGKey(0), X)\n",
    "model = nn.Model(LinearRegression, initial_params)\n",
    "optimizer = optimizer_def.create(model)\n",
    "\n",
    "train_steps = 100\n",
    "\n",
    "def loss_fn(model):\n",
    "  Y_hat = model(X)\n",
    "  return jnp.square(Y - Y_hat).mean()\n",
    "\n",
    "for i in range(train_steps):\n",
    "  loss, grad = jax.value_and_grad(loss_fn)(optimizer.target)\n",
    "  optimizer = optimizer.apply_gradient(grad)\n",
    "print('mean square error:', loss)\n",
    "\n",
    "trained_model = optimizer.target\n",
    "print(trained_model.params)\n",
    "plt.plot(x, f(x))\n",
    "plt.plot(x, trained_model(x))\n",
    "plt.scatter(X, Y)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CHQhNahuS0ed"
   },
   "source": [
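    "### Multi method modules\n",
    "\n",
    "A module can expose methods other than `apply` by decorating them with `@nn.module_method`. Such a method can look up parameters that were created in `apply` using `get_param`. In the example below, `apply` and `decode` share the same `kernel` parameter, and `decode` is used for step-by-step decoding."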
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiMethodModule(nn.Module):\n",
    "\n",
    "  def apply(self, x):\n",
    "    kernel = self.param('kernel', (), lambda _, shape: jnp.full(shape, 2.))\n",
    "    return x * kernel\n",
    "\n",
    "  @nn.module_method\n",
    "  def decode(self, x):\n",
    "    kernel = self.get_param('kernel')\n",
    "    return x * kernel\n",
    "\n",
    "x = 2. ** jnp.arange(5)\n",
    "y, initial_params = MultiMethodModule.init(random.PRNGKey(0), x)\n",
    "model = nn.Model(MultiMethodModule, initial_params)\n",
    "print('target:', x[1:], 'teacher forced decoding:', y[:-1])\n",
    "print('sequential decoding (one step):', model.decode(1.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def body_fn(carry, _):\n",
    "  y = model.decode(carry)\n",
    "  new_carry = y  # feed output back\n",
    "  return new_carry, y\n",
    "\n",
    "carry, ys = lax.scan(body_fn, 1., (), length=4)\n",
    "print('carry:', carry)\n",
    "print('sequential decoding:', ys)"
   ]
  },
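  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For readers unfamiliar with `lax.scan`, the cell below sketches what the loop above computes using plain Python. It covers only the case used here (no per-step inputs, a fixed `length`), and `scan_reference` and `toy_body_fn` are hypothetical names for this illustration. The real `lax.scan` additionally traces the body once and stacks the per-step outputs into an array; this sketch returns a Python list instead and hard-codes the kernel value 2 from `MultiMethodModule`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scan_reference(f, init, length):\n",
    "  # Hypothetical pure-Python stand-in for lax.scan without per-step inputs.\n",
    "  carry, ys = init, []\n",
    "  for _ in range(length):\n",
    "    carry, y = f(carry, None)\n",
    "    ys.append(y)\n",
    "  return carry, ys\n",
    "\n",
    "def toy_body_fn(carry, _):\n",
    "  y = carry * 2.  # same computation as decode with kernel == 2\n",
    "  return y, y     # feed the output back as the next carry\n",
    "\n",
    "toy_carry, toy_ys = scan_reference(toy_body_fn, 1., length=4)\n",
    "print('carry:', toy_carry)\n",
    "print('sequential decoding:', toy_ys)"
   ]
  },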
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "zbA5RyHVTC_V"
   },
   "source": [
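    "### Transforming sub module parameters\n",
    "\n",
    "A wrapper module can own the parameters of the module it wraps and transform them before use. The wrapper below stores the wrapped module's parameters as one of its own parameters, rescales the `kernel` with a learnable `scale` on every call, and then calls the wrapped module with the transformed parameters using `module.call`."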
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_scale(module):\n",
    "  class ScaleWrapper(nn.Module):\n",
    "    \"\"\"Add a learnable scale to the kernel of a module.\"\"\"\n",
    "\n",
    "    def apply(self, *args, **kwargs):\n",
    "      def init_fn(rng, _):\n",
    "        _, params = module.init(rng, *args, **kwargs)\n",
    "        # here we could change the initial parameters of the wrapped module\n",
    "        return params\n",
    "      params = self.param('params', None, init_fn)\n",
    "      # here we can transform the parameters on every call\n",
    "      assert 'kernel' in params\n",
    "      kernel = params['kernel']\n",
    "      features = kernel.shape[-1]\n",
    "      scale = self.param('scale', (features,), nn.initializers.ones)\n",
    "      scaled_kernel = kernel * scale\n",
    "      scaled_params = params.copy()\n",
    "      scaled_params['kernel'] = scaled_kernel\n",
    "\n",
    "      return module.call(scaled_params, *args, **kwargs)\n",
    "  return ScaleWrapper\n",
    "\n",
    "x = jnp.ones((1, 2))\n",
    "module = add_scale(Dense).partial(features=4)\n",
    "y, params = module.init(random.PRNGKey(0), x)\n",
    "params"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
